# Author: Kunal Sangani, ksangani@g.harvard.edu

import support_functions as sp
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import fsolve
import tikzplotlib
from scipy.interpolate import interp1d
from scipy.optimize import minimize
import matplotlib.pylab as pl
from numpy.linalg import inv
import klenow_willis as kw

# Model calibration (quarterly frequency).
# Shock persistence. From Gali (pg. 68): moderately persistent shock of 0.5.
# persistence_w: geometric decay rate of the monetary-policy shock v
#   (v[i] = v[i-1] * persistence_w in monetary_policy_impulse_responses).
# persistence_m: geometric decay rate of money-supply growth increments dM
#   in the money-supply experiments.
# NOTE(review): the Gali reference cites 0.5, but persistence_w is 0.7 — confirm intended.
persistence_w = 0.7
persistence_m = 0.5
# pi: per-quarter probability that a firm resets its price (a firm's first reset
# happens after t quarters with probability (1-pi)**t * pi, see augment_impulse_res),
# so the expected price duration is 1/pi = 2 quarters.
# NOTE(review): Gali (pg. 67) uses theta=3/4 (probability of NOT resetting, i.e.
# pi=0.25) for an average price duration of 4 quarters; pi=0.5 departs from that.
pi = 0.5
# Discount factor. From Gali (pg. 67): beta=0.99 --> annualized return of ~4%
beta = 0.99
# From Gali (pg. 67): Frisch elasticity of 0.2
frisch = 0.2
# Presumably an income-elasticity parameter (from the name); enters the output,
# labor and nominal-interest-rate equations. TODO: document source.
inc_elast = 0.2
# Monetary policy rule coefficients on output (phi_y) and inflation (phi_pi), from Gali (pg. ??).
# phi_y = 0.5 / 4 presumably converts an annualized coefficient to quarterly — confirm.
phi_y = 0.5 / 4
phi_pi = 1.5
# From Gali (pg.67): eta = 4. In the visible code it appears only in the
# money-demand variant of the money-supply shock.
eta = 4

# Shock sizes: initial value of the interest-rate shock v (described as a
# 25 basis point interest rate shock) and of the first money-supply increment.
MONETARY_SHOCK_SIZE = 0.25
MONEY_SUPPLY_SHOCK_SIZE = 0.25

# Recurring constant pibeta = pi * (1 - beta*(1-pi)) / (1-pi); appears as the
# extra diagonal term (1 + beta + pibeta) of the tridiagonal systems below.
pibeta = (pi - beta*pi*(1-pi))/(1-pi)

# Once the path of the price index is known, all other variables can be calculated
def compute_other_impulse_responses(delta, E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=50):
	"""Recover the remaining aggregate impulse responses from the path of ``delta``.

	``mu`` and ``A`` each satisfy a second-order linear difference equation driven
	by ``delta`` (set up according to the diff eq from kimball_sticky_corrected).
	With zero boundary conditions these become tridiagonal linear systems of size
	``periods``. Output, labor share and labor then follow in closed form.

	Args:
		delta: length-``periods`` path driving the system.
		E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho: sufficient
			statistics entering the equation coefficients.
		periods: horizon length (number of rows of each system).

	Returns:
		Tuple (mu, A, Y, Lambda, L) of length-``periods`` arrays.

	Raises:
		numpy.linalg.LinAlgError: if a system matrix is singular.
	"""
	delta = np.asarray(delta, dtype=float)

	def _tridiag(diag, lower, upper):
		# Constant-coefficient tridiagonal matrix: ``diag`` on the main diagonal,
		# ``lower`` on the sub-diagonal (t-1 terms), ``upper`` on the super-diagonal (t+1 terms).
		return (diag * np.eye(periods)
			+ lower * np.eye(periods, k=-1)
			+ upper * np.eye(periods, k=1))

	# mu and A share the same left-hand operator:
	#   -x[t-1] + (1 + beta + pibeta) x[t] - beta x[t+1]
	lhs = _tridiag(1 + beta + pibeta, -1, -beta)
	# mu: both boundary conditions on mu and delta are zero, so no constant terms are needed.
	delta_diag = 1 + beta + pibeta*(1 + (E_elastrho - E_elast*E_rho)/E_elast)
	# solve() instead of inv() @ b: more accurate and cheaper.
	mu = np.linalg.solve(lhs, _tridiag(delta_diag, -1, -beta) @ delta)
	# A: driven only by the contemporaneous delta term (a diagonal operator).
	delta_coef_A = -pibeta * (E_elast*E_mu_elastrho - E_mu_elast*E_elastrho) / E_elast
	A = np.linalg.solve(lhs, delta_coef_A * delta)
	# Output, labor share, and labor
	Y = 1/(1+inc_elast) * (A - frisch * mu)
	Lambda = -mu - A
	L = -inc_elast/(1+inc_elast)*A - frisch/(1+inc_elast)*mu
	return mu,A,Y,Lambda,L

def money_supply_impulse_responses_moneydemand(E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=50):
	"""Impulse responses to a persistent money-supply shock (money-demand variant).

	Builds a money-supply path ``M`` whose growth increments ``dM`` decay
	geometrically at rate ``persistence_m`` (scaled by MONEY_SUPPLY_SHOCK_SIZE).
	It then solves the linear difference equations for ``delta`` and recovers the
	remaining aggregates. The eta-dependent coefficients (including the t+2
	terms) in the money equation distinguish this variant from
	``money_supply_impulse_responses_cia``.

	The E_* arguments are sufficient statistics. In ``plot_all_impulse_responses``
	they are weighted means of firm elasticities and pass-throughs, with weights
	``shares`` or ``shares/markups``.

	Returns the tuple (M, delta, mu, A, Y, Lambda, L, w, i_nom1, r, p, dpi), each a
	length-``periods`` array. Prints diagnostic values along the way: the first
	entries of dM, M and i_nom1, inv_mu_bar, and consistency residuals.
	"""
	# Money-supply path: growth increments dM decay at rate persistence_m and the
	# level M accumulates them.
	dM = np.zeros(periods)
	M = np.zeros(periods)
	dM[0] = 1
	M[0] = dM[0] * MONEY_SUPPLY_SHOCK_SIZE
	for i in range(1,periods):
		dM[i] = dM[i-1] * persistence_m
		M[i] = M[i-1] + dM[i]*MONEY_SUPPLY_SHOCK_SIZE
	# Diagnostic output
	print('dM: ', dM[:15])
	print('M: ', M[:15])
	# Combination of sufficient statistics that multiplies pibeta*(1-E_rho) in the money equation
	inv_mu_bar = (E_elastrho - E_rho)/(E_mu_elastrho)
	print('inv_mu_bar:', inv_mu_bar)
	# Together with pibeta/E_elast, the coefficient linking A to delta
	constA = E_mu_elast*E_elastrho - E_elast*E_mu_elastrho
	# Solve: WEIGHTS_M_MU @ MU + WEIGHTS_M_A @ A = WEIGHTS_M_M @ M
	# WEIGHTS_A_A @ A = WEIGHTS_A_DELTA @ DELTA
	# WEIGHTS_MU_MU @ MU = WEIGHTS_MU_DELTA @ DELTA
	# Solve: [WEIGHTS_M_MU @ inv(WEIGHTS_MU_MU) @ WEIGHTS_MU_DELTA + WEIGHTS_M_A @ inv(WEIGHTS_A_A) @ WEIGHTS_A_DELTA] @ DELTA = WEIGHTS_M_M @ M
	WEIGHTS_M_M = np.zeros((periods, periods))
	WEIGHTS_M_A = np.zeros((periods, periods))
	WEIGHTS_M_MU = np.zeros((periods, periods))
	WEIGHTS_A_A = np.zeros((periods, periods))
	WEIGHTS_A_DELTA = np.zeros((periods, periods))
	WEIGHTS_MU_MU = np.zeros((periods, periods))
	WEIGHTS_MU_DELTA = np.zeros((periods, periods))
	# Set up matrices according to diff eq from kimball_sticky_corrected
	# The money equation (WEIGHTS_M_A / WEIGHTS_M_MU) has terms at t-1, t, t+1 and t+2;
	# the A and mu equations are tridiagonal (t-1, t, t+1). WEIGHTS_A_DELTA is diagonal.
	for i in range(0,periods):
		WEIGHTS_M_M[i][i] = 1 + beta
		WEIGHTS_M_A[i][i] = (1+eta*inc_elast)/(1+inc_elast)*(1+beta) + pibeta*(1-E_rho)*inv_mu_bar*(1+eta) - eta*inc_elast/(1+inc_elast)*(-1)
		WEIGHTS_M_MU[i][i] = -frisch*(1+eta*inc_elast)/(1+inc_elast)*(1+beta) - pibeta*E_rho*(1+eta) + frisch*eta*inc_elast/(1+inc_elast)*(-1)
		WEIGHTS_A_A[i][i] = 1+beta+pibeta
		WEIGHTS_A_DELTA[i][i] = pibeta*(constA)/E_elast
		WEIGHTS_MU_MU[i][i] = 1 + beta + pibeta
		WEIGHTS_MU_DELTA[i][i] = 1 + beta + pibeta*(1 + (E_elastrho - E_elast*E_rho)/E_elast)
		if i > 0:
			WEIGHTS_M_M[i][i-1] = -1
			WEIGHTS_M_A[i][i-1] = -(1+eta*inc_elast)/(1+inc_elast)
			WEIGHTS_M_MU[i][i-1] = frisch*(1+eta*inc_elast)/(1+inc_elast)
			WEIGHTS_A_A[i][i-1] = -1
			WEIGHTS_MU_MU[i][i-1] = -1
			WEIGHTS_MU_DELTA[i][i-1] = -1
		if i + 1 < periods:
			WEIGHTS_M_M[i][i+1] = -beta
			WEIGHTS_M_A[i][i+1] = -beta*(1+eta*inc_elast)/(1+inc_elast) - eta*inc_elast/(1+inc_elast)*(1 + beta) + pibeta*(1-E_rho)*inv_mu_bar*(-eta)
			WEIGHTS_M_MU[i][i+1] = beta*frisch*(1+eta*inc_elast)/(1+inc_elast) + frisch*eta*inc_elast/(1+inc_elast)*(1+beta) - pibeta*E_rho*(-eta)
			WEIGHTS_A_A[i][i+1] = -beta
			WEIGHTS_MU_MU[i][i+1] = -beta
			WEIGHTS_MU_DELTA[i][i+1] = -beta
		if i + 2 < periods:
			WEIGHTS_M_A[i][i+2] = -eta*inc_elast/(1+inc_elast)*(-beta)
			WEIGHTS_M_MU[i][i+2] = frisch*eta*inc_elast/(1+inc_elast)*(-beta)
	# Block elimination: A and mu are each linear in delta, so substitute them into
	# the money equation and solve for delta.
	delta = inv(WEIGHTS_M_MU @ inv(WEIGHTS_MU_MU) @ WEIGHTS_MU_DELTA + WEIGHTS_M_A @ inv(WEIGHTS_A_A) @ WEIGHTS_A_DELTA) @ WEIGHTS_M_M @ M
	# mu and A implied by delta; compared below against compute_other_impulse_responses
	mu2 = inv(WEIGHTS_MU_MU) @ WEIGHTS_MU_DELTA @ delta
	A2 = inv(WEIGHTS_A_A) @ WEIGHTS_A_DELTA @ delta
	# To solve for w:
	# Solve: WEIGHTS_D * delta = WEIGHTS_W * w 
	WEIGHTS_D = np.zeros((periods, periods))
	WEIGHTS_W = np.zeros((periods, periods))
	# Set up matrices according to diff eq from kimball_sticky_corrected
	for i in range(0,periods):
		WEIGHTS_D[i][i] = 1 + beta + pibeta * E_elastrho/E_elast
		WEIGHTS_W[i][i] = -(1 + beta)
		if i > 0:
			WEIGHTS_D[i][i-1] = -1
			WEIGHTS_W[i][i-1] = 1
		if i + 1 < periods:
			WEIGHTS_D[i][i+1] = -beta
			WEIGHTS_W[i][i+1] = beta
	w = np.linalg.inv(WEIGHTS_W) @ (WEIGHTS_D @ delta)
	mu,A,Y,Lambda,L = compute_other_impulse_responses(delta, E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=periods)
	# Output, price, labor, labor share, and nominal interest rate
	# Price level = markup + wage
	p = mu + w 
	# Consistency diagnostics: both routes to A and mu should agree, and p + Y should track M
	print('Adiff: ', (np.abs(A2-A)).max())
	print('mudiff: ', (np.abs(mu2-mu)).max())
	print('Mdiff: ', (np.abs(p+Y-M)).max())
	# Nominal rate: (inc_elast/frisch)*(Y[t+1]-Y[t]) + (p[t+1]-p[t]); the last entry
	# treats Y and p beyond the horizon as 0.
	i_nom1 = inc_elast/frisch*(-Y) + (-p)
	for ind in range(periods-1):
		i_nom1[ind] = inc_elast/frisch*(Y[ind+1]-Y[ind]) + (p[ind+1]-p[ind])
	print(i_nom1[:15])
	# Inflation: first difference of p, with dpi[0] = p[0] (pre-shock price level taken as 0)
	dpi = p + 0
	for ind in range(1, periods):
		dpi[ind] = p[ind] - p[ind-1]
	# Real rate: nominal rate minus next-period inflation (last entry left unadjusted)
	r = i_nom1.copy()
	for ind in range(periods-1):
		r[ind] -= dpi[ind+1]
	return M,delta,mu,A,Y,Lambda,L,w,i_nom1,r,p,dpi

def money_supply_impulse_responses_cia(E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=50):
	"""Impulse responses to a persistent money-supply shock (cash-in-advance variant).

	Uses the same money-supply path as ``money_supply_impulse_responses_moneydemand``.
	The money equation here has no eta-dependent terms and no t+2 terms (those
	lines are commented out below). ``delta`` is ultimately taken from a stacked
	3*periods linear system in (delta, A, mu); see the ``try_again`` branch.

	Returns the tuple (M, delta, mu, A, Y, Lambda, L, w, i_nom1, r, p, dpi), each a
	length-``periods`` array. Diagnostic prints are commented out.
	"""
	# Money-supply path: growth increments dM decay at rate persistence_m and the
	# level M accumulates them.
	dM = np.zeros(periods)
	M = np.zeros(periods)
	dM[0] = 1
	M[0] = dM[0] * MONEY_SUPPLY_SHOCK_SIZE
	for i in range(1,periods):
		dM[i] = dM[i-1] * persistence_m
		M[i] = M[i-1] + dM[i]*MONEY_SUPPLY_SHOCK_SIZE
	# Combination of sufficient statistics that multiplies pibeta*(1-E_rho) in the money equation
	inv_mu_bar = (E_elastrho - E_rho)/(E_mu_elastrho)
	# Together with pibeta/E_elast, the coefficient linking A to delta
	constA = E_mu_elast*E_elastrho - E_elast*E_mu_elastrho
	# Solve: WEIGHTS_M_MU @ MU + WEIGHTS_M_A @ A = WEIGHTS_M_M @ M
	# WEIGHTS_A_A @ A = WEIGHTS_A_DELTA @ DELTA
	# WEIGHTS_MU_MU @ MU = WEIGHTS_MU_DELTA @ DELTA
	# Solve: [WEIGHTS_M_MU @ inv(WEIGHTS_MU_MU) @ WEIGHTS_MU_DELTA + WEIGHTS_M_A @ inv(WEIGHTS_A_A) @ WEIGHTS_A_DELTA] @ DELTA = WEIGHTS_M_M @ M
	WEIGHTS_M_M = np.zeros((periods, periods))
	WEIGHTS_M_A = np.zeros((periods, periods))
	WEIGHTS_M_MU = np.zeros((periods, periods))
	WEIGHTS_A_A = np.zeros((periods, periods))
	WEIGHTS_A_DELTA = np.zeros((periods, periods))
	WEIGHTS_MU_MU = np.zeros((periods, periods))
	WEIGHTS_MU_DELTA = np.zeros((periods, periods))
	# Set up matrices according to diff eq from kimball_sticky_corrected
	# All systems here are tridiagonal (t-1, t, t+1); WEIGHTS_A_DELTA is diagonal.
	for i in range(0,periods):
		WEIGHTS_M_M[i][i] = 1 + beta
		WEIGHTS_M_A[i][i] = (1)/(1+inc_elast)*(1+beta) + pibeta*(1-E_rho)*inv_mu_bar*(1) 
		WEIGHTS_M_MU[i][i] = -frisch*(1)/(1+inc_elast)*(1+beta) - pibeta*E_rho*(1) 
		WEIGHTS_A_A[i][i] = 1+beta+pibeta
		WEIGHTS_A_DELTA[i][i] = pibeta*(constA)/E_elast
		WEIGHTS_MU_MU[i][i] = 1 + beta + pibeta
		WEIGHTS_MU_DELTA[i][i] = 1 + beta + pibeta*(1 + (E_elastrho - E_elast*E_rho)/E_elast)
		if i > 0:
			WEIGHTS_M_M[i][i-1] = -1
			WEIGHTS_M_A[i][i-1] = -(1)/(1+inc_elast)
			WEIGHTS_M_MU[i][i-1] = frisch*(1)/(1+inc_elast)
			WEIGHTS_A_A[i][i-1] = -1
			WEIGHTS_MU_MU[i][i-1] = -1
			WEIGHTS_MU_DELTA[i][i-1] = -1
		if i + 1 < periods:
			WEIGHTS_M_M[i][i+1] = -beta
			WEIGHTS_M_A[i][i+1] = -beta*(1)/(1+inc_elast) 
			WEIGHTS_M_MU[i][i+1] = beta*frisch*(1)/(1+inc_elast) 
			WEIGHTS_A_A[i][i+1] = -beta
			WEIGHTS_MU_MU[i][i+1] = -beta
			WEIGHTS_MU_DELTA[i][i+1] = -beta
		# if i + 2 < periods:
		# 	WEIGHTS_M_A[i][i+2] = -eta*inc_elast/(1+inc_elast)*(-beta)
		# 	WEIGHTS_M_MU[i][i+2] = frisch*eta*inc_elast/(1+inc_elast)*(-beta)
	# Block elimination: substitute A(delta) and mu(delta) into the money equation and solve for delta
	delta = inv(WEIGHTS_M_MU @ inv(WEIGHTS_MU_MU) @ WEIGHTS_MU_DELTA + WEIGHTS_M_A @ inv(WEIGHTS_A_A) @ WEIGHTS_A_DELTA) @ WEIGHTS_M_M @ M
	mu2 = inv(WEIGHTS_MU_MU) @ WEIGHTS_MU_DELTA @ delta
	A2 = inv(WEIGHTS_A_A) @ WEIGHTS_A_DELTA @ delta
	# try_again is hard-coded True, so the stacked solve below always replaces the
	# delta, A2 and mu2 computed above.
	try_again = True
	if try_again:
		# We'll set up a matrix of WEIGHTS_D * (delta, A, mu) = WEIGHTS_M * M
		# Rows [0, periods): money equation in A and mu (no direct delta terms);
		# rows [periods, 2*periods): A equation; rows [2*periods, 3*periods): mu equation.
		WEIGHTS_D = np.zeros((3*periods, 3*periods))
		WEIGHTS_M = np.zeros((periods, periods))
		for i in range(periods):
			WEIGHTS_D[i][periods+i] = 1/(1+inc_elast)*(1+beta) + pibeta*(E_elastrho-E_rho)/E_mu_elastrho*(1-E_rho)
			WEIGHTS_D[i][2*periods+i] = -frisch/(1+inc_elast)*(1+beta) - pibeta*E_rho
			WEIGHTS_M[i][i] = 1 + beta
			if i >= 1:
				WEIGHTS_D[i][periods+i-1] = 1/(1+inc_elast)*(-1)
				WEIGHTS_D[i][2*periods+i-1] = frisch/(1+inc_elast)
				WEIGHTS_M[i][i-1] = -1
			if i <= periods - 2:
				WEIGHTS_D[i][periods+i+1] = 1/(1+inc_elast)*(-beta)
				WEIGHTS_D[i][2*periods+i+1] = frisch/(1+inc_elast)*beta
				WEIGHTS_M[i][i+1] = -beta
			# Submatrices
			if (i >= 1) and (i <= periods - 1):
				WEIGHTS_D[periods+i][periods+i-1] = -1
			if i <= periods - 1:
				WEIGHTS_D[periods+i][periods+i] = 1 + beta + pibeta
			if i <= periods - 2:
				WEIGHTS_D[periods+i][periods+i+1] = -beta
			if i <= periods - 1:
				WEIGHTS_D[periods+i][i] = -(pibeta * constA / E_elast)
			if i <= periods - 1:
				WEIGHTS_D[2*periods+i][2*periods+i] = 1 + beta + pibeta
				WEIGHTS_D[2*periods+i][i] = -(1 + beta + pibeta*(1 + (E_elastrho-E_elast*E_rho)/E_elast))
			if (i >= 1) and (i <= periods - 1):
				WEIGHTS_D[2*periods+i][2*periods+i-1] = -1
				WEIGHTS_D[2*periods+i][i-1] = -(-1)
			if i <= periods - 2:
				WEIGHTS_D[2*periods+i][2*periods+i+1] = -beta
				WEIGHTS_D[2*periods+i][i+1] = -(-beta)
		# Only the money-equation rows have a nonzero right-hand side
		RHS = WEIGHTS_M @ M
		RHS = np.pad(RHS, (0, 2*periods), 'constant')
		SOL = np.linalg.inv(WEIGHTS_D) @ RHS
		delta = SOL[:periods]
		A2 = SOL[periods : 2*periods]
		mu2 = SOL[2*periods:]
	# To solve for w:
	# Solve: WEIGHTS_D * delta = WEIGHTS_W * w 
	WEIGHTS_D = np.zeros((periods, periods))
	WEIGHTS_W = np.zeros((periods, periods))
	# Set up matrices according to diff eq from kimball_sticky_corrected
	for i in range(0,periods):
		WEIGHTS_D[i][i] = 1 + beta + pibeta * E_elastrho/E_elast
		WEIGHTS_W[i][i] = -(1 + beta)
		if i > 0:
			WEIGHTS_D[i][i-1] = -1
			WEIGHTS_W[i][i-1] = 1
		if i + 1 < periods:
			WEIGHTS_D[i][i+1] = -beta
			WEIGHTS_W[i][i+1] = beta
	w = np.linalg.inv(WEIGHTS_W) @ (WEIGHTS_D @ delta)
	mu,A,Y,Lambda,L = compute_other_impulse_responses(delta, E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=periods)
	# Output, price, labor, labor share, and nominal interest rate
	# Price level = markup + wage
	p = mu + w 
	# Consistency diagnostics (disabled):
	# print('Adiff: ', (np.abs(A2-A)).max())
	# print('mudiff: ', (np.abs(mu2-mu)).max())
	# print('Mdiff: ', (np.abs(p+Y-M)).max())
	# Nominal rate: (inc_elast/frisch)*(Y[t+1]-Y[t]) + (p[t+1]-p[t]); the last entry
	# treats Y and p beyond the horizon as 0.
	i_nom1 = inc_elast/frisch*(-Y) + (-p)
	for ind in range(periods-1):
		i_nom1[ind] = inc_elast/frisch*(Y[ind+1]-Y[ind]) + (p[ind+1]-p[ind])
	# Inflation: first difference of p, with dpi[0] = p[0] (pre-shock price level taken as 0)
	dpi = p + 0
	for ind in range(1, periods):
		dpi[ind] = p[ind] - p[ind-1]
	# Real rate: nominal rate minus next-period inflation (last entry left unadjusted)
	r = i_nom1.copy()
	for ind in range(periods-1):
		r[ind] -= dpi[ind+1]
	return M,delta,mu,A,Y,Lambda,L,w,i_nom1,r,p,dpi



def monetary_policy_impulse_responses(E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=50):
	"""Impulse responses to a persistent monetary-policy (interest-rate rule) shock.

	The shock ``v`` starts at MONETARY_SHOCK_SIZE and decays geometrically at rate
	``persistence_w``. ``delta`` is solved from a stacked 3*periods linear system
	in (delta, A, mu):
	  - rows [0, periods): the equation combining the policy-rule coefficients
	    (phi_y, phi_pi) with the other model equations — presumably the Taylor
	    rule substituted into the Euler equation;
	  - rows [periods, 2*periods): the A equation;
	  - rows [2*periods, 3*periods): the mu equation.
	The remaining aggregates are then recovered from ``delta``.

	Returns the tuple (v, delta, mu, A, Y, Lambda, L, w, i_nom1, r, p, dpi, M), each
	a length-``periods`` array. ``i_nom1`` (nominal rate) and ``r`` (real rate) are
	independent arrays.

	Raises:
		numpy.linalg.LinAlgError: if a system matrix is singular.
	"""
	v = np.zeros(periods)
	# Set v vector as MIT shock: starts at MONETARY_SHOCK_SIZE, decays at persistence_w
	v[0] = MONETARY_SHOCK_SIZE
	for i in range(1,periods):
		v[i] = v[i-1] * persistence_w
	constA = E_mu_elast*E_elastrho - E_elast*E_mu_elastrho
	constB = inc_elast/(1+inc_elast)
	# Coefficients of the first equation block. Their column placement is set in
	# the loop below: suffix "0" -> column t+1, suffix "n1" -> column t, and the
	# ddelta term enters columns t, t+1, t+2.
	coeff_ddelta = -constB
	coeff_delta0 = pibeta*(constB/frisch*constA/E_elast - constB*E_elastrho/E_elast + (1-constB)*(1-E_rho))
	coeff_deltan1 = -phi_pi*pibeta*(1-E_rho)
	coeff_A0 = (-constB/frisch*pibeta + beta*phi_y/(1+inc_elast))
	coeff_mu0 = (-(1 - constB)*pibeta - beta*frisch*phi_y/(1+inc_elast))
	coeff_An1 = -phi_y/(1+inc_elast)
	coeff_mun1 = frisch*phi_y/(1+inc_elast) + phi_pi*pibeta
	# We'll set up a matrix of WEIGHTS_D * (delta, A, mu) = WEIGHTS_V * v
	WEIGHTS_D = np.zeros((3*periods, 3*periods))
	WEIGHTS_V = np.zeros((periods, periods))
	for i in range(periods):
		if i <= periods - 1:
			WEIGHTS_D[i][i] = -coeff_ddelta + coeff_deltan1
		if i <= periods - 2:
			WEIGHTS_D[i][i+1] = coeff_delta0 + (1+beta)*coeff_ddelta
		if i <= periods - 3:
			WEIGHTS_D[i][i+2] = -beta*coeff_ddelta
		if i <= periods - 1:
			WEIGHTS_D[i][periods+i] = coeff_An1
		if i <= periods - 2:
			WEIGHTS_D[i][periods+i+1] = coeff_A0
		if i <= periods - 1:
			WEIGHTS_D[i][2*periods+i] = coeff_mun1
		if i <= periods - 2:
			WEIGHTS_D[i][2*periods+i+1] = coeff_mu0
		# Right-hand side: v[t] - beta * v[t+1]
		if i <= periods - 1:
			WEIGHTS_V[i][i] = 1
		if i <= periods - 2:
			WEIGHTS_V[i][i+1] = -beta
		# SUBMATRICES
		if (i >= 1) and (i <= periods - 1):
			WEIGHTS_D[periods+i][periods+i-1] = -1
		if i <= periods - 1:
			WEIGHTS_D[periods+i][periods+i] = 1 + beta + pibeta
		if i <= periods - 2:
			WEIGHTS_D[periods+i][periods+i+1] = -beta
		if i <= periods - 1:
			WEIGHTS_D[periods+i][i] = -(pibeta * constA / E_elast)
		if i <= periods - 1:
			WEIGHTS_D[2*periods+i][2*periods+i] = 1 + beta + pibeta
			WEIGHTS_D[2*periods+i][i] = -(1 + beta + pibeta*(1 + (E_elastrho-E_elast*E_rho)/E_elast))
		if (i >= 1) and (i <= periods - 1):
			WEIGHTS_D[2*periods+i][2*periods+i-1] = -1
			WEIGHTS_D[2*periods+i][i-1] = -(-1)
		if i <= periods - 2:
			WEIGHTS_D[2*periods+i][2*periods+i+1] = -beta
			WEIGHTS_D[2*periods+i][i+1] = -(-beta)
	# Only the first equation block has a nonzero right-hand side
	RHS = WEIGHTS_V @ v
	RHS = np.pad(RHS, (0, 2*periods), 'constant')
	# solve() instead of inv() @ RHS: more accurate and cheaper
	SOL = np.linalg.solve(WEIGHTS_D, RHS)
	# SOL also holds A and mu (slices [periods:2*periods] and [2*periods:]);
	# they are recomputed from delta below.
	delta = SOL[:periods]
	mu,A,Y,Lambda,L = compute_other_impulse_responses(delta, E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=periods)
	# To solve for w:
	# Solve: WEIGHTS_D * delta = WEIGHTS_W * w 
	WEIGHTS_D = np.zeros((periods, periods))
	WEIGHTS_W = np.zeros((periods, periods))
	# Set up matrices according to diff eq from kimball_sticky_corrected
	for i in range(0,periods):
		WEIGHTS_D[i][i] = 1 + beta + pibeta * E_elastrho/E_elast
		WEIGHTS_W[i][i] = -(1 + beta)
		if i > 0:
			WEIGHTS_D[i][i-1] = -1
			WEIGHTS_W[i][i-1] = 1
		if i + 1 < periods:
			WEIGHTS_D[i][i+1] = -beta
			WEIGHTS_W[i][i+1] = beta
	w = np.linalg.solve(WEIGHTS_W, WEIGHTS_D @ delta)
	# Price level = markup + wage
	p = mu + w 
	# Nominal rate: (inc_elast/frisch)*(Y[t+1]-Y[t]) + (p[t+1]-p[t]); the last entry
	# treats Y and p beyond the horizon as 0.
	i_nom1 = inc_elast/frisch*(-Y) + (-p)
	for ind in range(periods-1):
		i_nom1[ind] = inc_elast/frisch*(Y[ind+1]-Y[ind]) + (p[ind+1]-p[ind])
	# Inflation: first difference of p, with dpi[0] = p[0] (pre-shock price level taken as 0)
	dpi = p + 0
	for ind in range(1, periods):
		dpi[ind] = p[ind] - p[ind-1]
	# Real rate: nominal rate minus next-period inflation (last entry left unadjusted).
	# Copy so the in-place subtraction does not overwrite i_nom1, the returned nominal rate.
	r = i_nom1.copy()
	for ind in range(periods-1):
		r[ind] -= dpi[ind+1]
	M = p+Y
	return v,delta,mu,A,Y,Lambda,L,w,i_nom1,r,p,dpi,M

def augment_impulse_res(impulse_res, var_list, shares, markups, passthroughs, elast, periods, t_max, savefileroot=None):
	"""Derive firm-level responses from aggregate IRFs, plot diagnostics, and return a dispersion IRF.

	Reads the aggregate paths 'w', 'p', 'delta' and 'Y' from ``impulse_res`` via
	their positions in ``var_list``. The aggregate price path is
	``P_agg = delta + w``. Each firm's price path (one row per firm) solves the
	tridiagonal system WEIGHTS_P @ dp = pibeta * (rho * w + (1 - rho) * P_agg),
	where rho is that firm's pass-through.

	Side effects:
	  - Prints a LaTeX table with the share of firms that have reset their price
	    and an average price-change measure for the first (up to) 8 periods. The
	    measure is reported for 5 groups of firms, ordered as in the input arrays.
	  - Opens several matplotlib figures.
	  - If ``savefileroot`` is set, writes '<root>_cov_detail.tex',
	    '<root>_cov_detail_invmu.tex' and '<root>_salesshares.tex' via tikzplotlib.

	Args:
		impulse_res: sequence of aggregate IRF arrays aligned with ``var_list``.
		var_list: variable names; must contain 'w', 'p', 'delta' and 'Y'.
		shares, markups, passthroughs, elast: per-firm arrays of equal length.
		periods: IRF horizon (length of each aggregate path).
		t_max: number of periods shown in the covariance / variance plots.
		savefileroot: optional path prefix for tikz output.

	Returns:
		Length-``periods`` array. Entry t is Cov(log mu, dmu_theta[:, t]) / Var(log mu),
		both computed by ``sp.weighted_cov`` with equal weights across firms.
	"""
	def _sync_private_grid_flags(ax):
		# Mirror the tick-kw grid flags into the private _gridOn* attributes.
		# Presumably a workaround so tikzplotlib can read grid state under newer
		# matplotlib versions; relies on matplotlib internals.
		for axis in (ax.get_xaxis(), ax.get_yaxis()):
			axis._gridOnMajor = axis._major_tick_kw['gridOn']
			axis._gridOnMinor = axis._minor_tick_kw['gridOn']

	def _plot_invmu_covariances(filesuffix):
		# Two-panel figure of Cov(-1/mu, dlog mu) and Cov(-1/mu, dlog Costs);
		# saved as savefileroot + filesuffix when savefileroot is set.
		_, axs = plt.subplots(1,2)
		axs[0].plot(cov_invmu_dlogmu[:t_max], marker='o', fillstyle='none', color='blue')
		axs[0].set_title('$Cov_{\\lambda}(-1/\\mu, d\\log \\mu)$')
		axs[1].plot(cov_invmu_dlogcosts[:t_max], marker='o', fillstyle='none', color='blue')
		axs[1].set_title('$Cov_{\\lambda}(-1/\\mu, d\\log$ Costs$)$')
		for ax in axs:
			ax.ticklabel_format(useOffset=False, style='plain')
			_sync_private_grid_flags(ax)
		plt.tight_layout(rect=[0, 0, 1, 0.95])
		if savefileroot:
			tikzplotlib.clean_figure()
			tikzplotlib.save(savefileroot + filesuffix, extra_axis_parameters=['PlotStyle'])

	w = impulse_res[var_list.index('w')]
	p = impulse_res[var_list.index('p')]
	# Aggregate price path implied by delta and the nominal wage
	P_agg = impulse_res[var_list.index('delta')] + w
	Y_agg = impulse_res[var_list.index('Y')]
	# Normalized cumulative shares in input order (used to pick firms for the sales-share plot)
	cum_shares = np.cumsum(shares) / np.sum(shares)
	# Per-firm inputs as (n_firms, 1) columns; shares normalized to sum to 1
	shares = np.reshape(shares, (-1,1))
	shares = shares / sum(shares)
	passthroughs = np.reshape(passthroughs, (-1,1))
	elast = np.reshape(elast, (-1,1))
	markups = np.reshape(markups, (-1,1))
	# Aggregate paths as (periods, 1) columns
	w = np.reshape(w, (-1,1))
	p = np.reshape(p, (-1,1))
	P_agg = np.reshape(P_agg, (-1,1))
	# Tridiagonal operator of the reset-price difference equation
	WEIGHTS_P = np.zeros((periods, periods))
	for i in range(0,periods):
		WEIGHTS_P[i][i] = 1 + beta + pibeta
		if i > 0:
			WEIGHTS_P[i][i-1] = -1
		if i + 1 < periods:
			WEIGHTS_P[i][i+1] = -beta
	# Firm price paths, shape (n_firms, periods). Row k solves
	# WEIGHTS_P @ dp_k = pibeta * (rho_k * w + (1 - rho_k) * P_agg).
	dp_theta = (pibeta * (passthroughs @ w.T + (1-passthroughs) @ P_agg.T)) @ np.linalg.inv(WEIGHTS_P.T)
	# Firm markup changes: price change minus wage change
	dmu_theta = dp_theta - w.T
	### Share of price resetters and size of first price change conditional on price change
	frac_first_reset = np.zeros(len(w))
	share_price_change = np.zeros(len(w))
	reset_prices = np.zeros((len(passthroughs), len(w)))
	avg_price_change = np.zeros((len(passthroughs), len(w)))
	for t in range(len(w)):
		# Probability that a firm's first price reset occurs in period t
		frac_first_reset[t] = (1-pi)**(t) * pi
		# Cumulative share of firms that have reset by period t
		share_price_change[t] = sum(frac_first_reset[:t+1])
		# Reset price chosen at t: (1 - beta*(1-pi)) times the discounted sum of the
		# per-firm targets rho*w + (1-rho)*P_agg over the remaining horizon
		reset_prices_t = (1-beta*(1-pi)) * sum([(beta*(1-pi))**(i-t) * (passthroughs*w[i]+(1-passthroughs)*P_agg[i]) for i in range(t, len(w))])
		reset_prices[:,t] = reset_prices_t.ravel()
		# Per firm: |sum over s <= t of frac_first_reset[s] * reset_prices[:, s]|
		avg_price_change[:,t] = abs((np.diag(frac_first_reset[:t+1]) @ reset_prices[:,:t+1].T).sum(axis=0))
	avg_price_change_df = pd.DataFrame(avg_price_change)
	# Five groups of firms by position in the input order
	avg_price_change_df['group'] = [np.floor(i / (len(passthroughs) / 5)) for i in range(len(passthroughs))]
	full_df = pd.DataFrame(index=avg_price_change_df['group'].value_counts().sort_index().index)
	# Tabulate the first (up to) 8 periods; capped so short horizons do not raise IndexError
	for i in range(min(8, len(w))):
		full_df['share_price_change_' + str(i)] = share_price_change[i]
		full_df['size_price_change_' + str(i)] = avg_price_change_df.groupby(['group'])[i].mean()
	print(full_df.to_latex(float_format='%0.4f'))
	###
	# Firm responses: dlambda_theta (sales shares) and dn_theta (labelled 'd log Costs' in plots)
	dlambda_theta = np.diag((1-elast).ravel()) @ dp_theta + elast @ P_agg.T - np.ones(elast.shape) @ p.T
	dn_theta = np.diag(-elast.ravel()) @ dp_theta + elast @ P_agg.T + Y_agg.T
	markups_flat = markups.ravel()
	share_weights = shares.ravel()
	cov_mu_dlogmu = np.array([sp.weighted_cov(markups_flat, dmu_theta[:,t], share_weights) for t in range(periods)])
	plt.figure()
	plt.plot(range(t_max), cov_mu_dlogmu[:t_max], marker='o', color='blue', fillstyle='none')
	plt.xlabel('Quarters')
	plt.ylabel('$Cov_{\\lambda}(\\mu, d\\log \\mu)$')
	cov_mu_dlogcosts = np.array([sp.weighted_cov(markups_flat, dn_theta[:,t], share_weights) for t in range(periods)])
	plt.figure()
	plt.plot(range(t_max), cov_mu_dlogcosts[:t_max], marker='o', color='blue', fillstyle='none')
	plt.xlabel('Quarters')
	plt.ylabel('$Cov_{\\lambda}(\\mu, d\\log Costs)$')
	## Inv mu covariances
	cov_invmu_dlogmu = np.array([sp.weighted_cov(-1 / markups_flat, dmu_theta[:,t], share_weights) for t in range(periods)])
	cov_invmu_dlogcosts = np.array([sp.weighted_cov(-1 / markups_flat, dn_theta[:,t], share_weights) for t in range(periods)])
	log_markups = np.log(markups_flat)
	var_logmu = np.array([sp.weighted_cov(log_markups + dmu_theta[:,t], log_markups + dmu_theta[:,t], share_weights) for t in range(periods)])
	plt.figure()
	plt.plot(range(t_max), var_logmu[:t_max], marker='o', color='blue', fillstyle='none')
	plt.xlabel('Quarters')
	plt.ylabel('$Var_{\\lambda}(\\log \\mu)$')
	# NOTE(review): both files receive the same -1/mu covariance figure; confirm
	# whether '_cov_detail' was meant to show the mu covariances instead.
	_plot_invmu_covariances('_cov_detail.tex')
	_plot_invmu_covariances('_cov_detail_invmu.tex')
	# Per-period ratio Cov(log mu, dmu_theta[:, t]) / Var(log mu), both with equal weights
	equal_weights = np.ones(log_markups.shape)
	var_logmu_equal = sp.weighted_cov(log_markups, log_markups, equal_weights)
	dlogstd_mu = np.zeros(periods)
	for i in range(periods):
		dlogstd_mu[i] = sp.weighted_cov(log_markups, dmu_theta[:,i], equal_weights)/var_logmu_equal
	# Firms to plot: position (as a fraction of the firm count) where the normalized
	# cumulative share first reaches each sales percentile
	sales_pcts_to_plot = [0.2, 0.5, 0.7, 0.9, 1.0]
	firm_pcts_to_plot = np.searchsorted(cum_shares, sales_pcts_to_plot) / len(cum_shares)
	plot_colors = pl.cm.viridis(np.linspace(0, 1, len(firm_pcts_to_plot))[::-1])
	# Fresh figure: without it these lines would be drawn onto (and retitle) the last
	# subplot of the previous figure, which would then be saved as _salesshares.tex.
	plt.figure()
	for ind,i in enumerate(firm_pcts_to_plot):
		label = int(100 * sales_pcts_to_plot[ind])
		plt.plot(dlambda_theta[int(i * (len(shares)-1))][:15], marker='o', fillstyle='none', label='$d\\log \\lambda_{' + str(label) + '}$', color=plot_colors[ind])
	plt.title('Impulse response: Sales shares')
	plt.legend()
	if savefileroot is not None:
		plt.title(None)
		tikzplotlib.clean_figure()
		tikzplotlib.save(savefileroot + '_salesshares.tex', extra_axis_parameters=['PlotStyle'])
	return dlogstd_mu

def get_var_display_names(var_list):
	"""Return the LaTeX plot label for each variable name in ``var_list``, in order.

	Raises KeyError if a name has no registered label.
	"""
	labels_by_var = dict(
		i_nom='Quarterly interest rate ($d\\log i_t$)',
		delta='$d \\log(\\bar{\\delta}_t/Y)$',
		mu='$E_{\\lambda}[E^{\\theta}[d \\log \\mu_{\\theta, t}]]$',
		A='Aggregate TFP ($d \\log A_t$)',
		Y='Output ($d \\log Y_t$)',
		w='Nominal wage ($d \\log w_t$)',
		Lambda='$d \\log \\Lambda_{L,t}$',
		L='Labor ($d \\log L_t$)',
		p='CPI ($d\\log P^Y_t$)',
		dpi='Inflation ($d\\log \\pi_t$)',
		v='Monetary shock ($v_t$)',
		r='Real interest rate ($r_t$)',
		std_mu='TFPR dispersion ($d\\log Std([\\log \\mu_{\\theta,t}])$)',
		M='Money Supply ($d\\log M_t$)',
	)
	return [labels_by_var[name] for name in var_list]

def calculate_ir_stats(impulse_res, impulse_res_homog, impulse_res_CES, Y_index):
	"""Print persistence, impact and cumulative-loss statistics of the output IRF.

	``impulse_res[Y_index]`` (and likewise for the other two) is the output path.
	The CES case is processed first and serves as the baseline. The homogeneous
	and heterogeneous cases are reported relative to it. Results are printed;
	nothing is returned.

	Persistence is the fractional period at which output, normalized by its
	impact value, first decays to each of 1/2, 1/4, 1/8 and 1/16. It is nan when
	the threshold is not reached within the horizon.
	"""
	def find_index_in_decr_arr(arr, val):
		# Fractional index at which the (assumed decreasing) array falls to ``val``,
		# linearly interpolated between the two bracketing entries.
		# Returns nan when ``val`` is not bracketed: the array never drops to ``val``
		# within the horizon, or starts below it. This avoids an IndexError or a
		# silent wrap-around via a negative index.
		z = len(arr) - np.searchsorted(arr[::-1],val)
		if z <= 0 or z >= len(arr):
			return np.nan
		lamb = (val-arr[z-1])/(arr[z]-arr[z-1])
		return z-1+lamb
	impulse_res_list = [impulse_res, impulse_res_homog, impulse_res_CES]
	names = ['Heterogeneous', 'Homogeneous', 'CES']
	base_pres = None
	base_loss = None
	base_effect0 = None
	half_lives = [0.5, 0.25, 0.125, 0.0625]
	# Iterate in reverse so the CES baseline is computed before it is used
	for i in range(len(impulse_res_list))[::-1]:
		print(names[i] + ' -------------------------')
		pers = impulse_res_list[i][Y_index]/impulse_res_list[i][Y_index][0]
		pers_vals = np.array([find_index_in_decr_arr(pers, val) for val in half_lives])
		print('Persistence: ', pers_vals)
		if names[i]=='CES':
			base_pres = pers_vals
		else:
			print('Increased persistence over base: ', pers_vals / base_pres)
		print('Output effect t=0: ', impulse_res_list[i][Y_index][0])
		if names[i]=='CES':
			base_effect0 = impulse_res_list[i][Y_index][0]
		else:
			print('Increased output effect t=0: ', impulse_res_list[i][Y_index][0] / base_effect0)
		# Cumulative loss: average of the cumulative sum over the window [T/2, 4T/5)
		cumulative_loss_arr = impulse_res_list[i][Y_index].cumsum()
		length = len(cumulative_loss_arr)
		cumulative_loss = cumulative_loss_arr[int(length/2):int(4*length/5)].mean()
		print('Cumulative loss: ', cumulative_loss)
		if names[i]=='CES':
			base_loss = cumulative_loss
		else:
			print('Increased cum. loss over base: ', cumulative_loss / base_loss)

# Plots impulse responses for heterogenous firms, homog. firms, and CES versions
def plot_all_impulse_responses(shock_function, var_list, suptitle_base, shares, markups, passthroughs, periods=50, savefileroot=None, shock_type='interest_rate'):
	t_max = 15
	# Calculate all sufficient stats
	elast = 1/(1 - (1/markups))
	sharemu = shares/markups
	E_elast = sp.weighted_exp(elast, shares)
	E_rho = sp.weighted_exp(passthroughs, shares)
	E_elastrho = sp.weighted_exp(elast * passthroughs, shares)
	E_mu_elastrho = sp.weighted_exp(elast * passthroughs, sharemu)
	E_mu_elast = sp.weighted_exp(elast, sharemu)
	# No heterogeneity
	E_elast_homog = sp.weighted_exp(elast, shares)
	E_rho_homog = sp.weighted_exp(passthroughs, shares)
	#print('E_ELAST_HOMOG', E_elast_homog)
	impulse_res = shock_function(E_elast, E_rho, E_elastrho, E_mu_elast, E_mu_elastrho, periods=periods)
	impulse_res_homog = shock_function(E_elast_homog, E_rho_homog, E_elast_homog * E_rho_homog, E_elast_homog, E_elast_homog * E_rho_homog, periods=periods)
	impulse_res_CES = shock_function(E_elast_homog, 1, E_elast_homog, E_elast_homog, E_elast_homog, periods=periods)
	impulse_res = list(impulse_res)
	impulse_res_homog = list(impulse_res_homog)
	impulse_res_CES = list(impulse_res_CES)
	impulse_res.append(augment_impulse_res(impulse_res, var_list, shares, markups, passthroughs, elast, periods=periods, t_max=t_max, savefileroot=savefileroot))
	impulse_res_homog.append(np.zeros(periods))
	impulse_res_CES.append(np.zeros(periods))
	var_list.append('std_mu')
	labels = get_var_display_names(var_list)
	suptitles = [suptitle_base + ' (Heterogenous firms)', suptitle_base + ' (Homogenous firms)', suptitle_base + ' (CES)']
	filesuffix = ['heterogenous', 'homogenous', 'ces']
	# for i,var in enumerate(impulse_res):
	# 	plt.figure()
	# 	plt.plot(impulse_res_CES[i][:t_max], marker='o', color='green', fillstyle='none')
	# 	plt.plot(impulse_res_homog[i][:t_max], marker='o', color='darkorange', fillstyle='none')
	# 	plt.plot(impulse_res[i][:t_max], marker='o', color='blue', fillstyle='none')
	# 	plt.gca().ticklabel_format(useOffset=False, style='plain')
	# 	plt.title(labels[i])
	# 	if savefileroot:
	# 		tikzplotlib.clean_figure()
	# 		plt.gca().get_xaxis()._gridOnMajor = plt.gca().get_xaxis()._major_tick_kw['gridOn']
	# 		plt.gca().get_xaxis()._gridOnMinor = plt.gca().get_xaxis()._minor_tick_kw['gridOn']
	# 		plt.gca().get_yaxis()._gridOnMajor = plt.gca().get_yaxis()._major_tick_kw['gridOn']
	# 		plt.gca().get_yaxis()._gridOnMinor = plt.gca().get_yaxis()._minor_tick_kw['gridOn']
	# 		tikzplotlib.save(savefileroot + '_' + var_list[i] + '.tex', extra_axis_parameters=['IRFStyle'])
	if shock_type=='interest_rate':
		mod_var_list = ['v', 'p', 'L', 'Y', 'A', 'std_mu']
		mod_var_list2 = ['i_nom', 'dpi']
	else:
		mod_var_list = ['M', 'i_nom', 'p', 'dpi', 'A', 'Y', 'L', 'std_mu']
		mod_var_list2 = None
	fig,axs = plt.subplots(int((len(mod_var_list) + 1)/2),2,figsize=(8,1.5*int((len(var_list) + 1)/2)))
	axs = axs.flatten()
	for plot_ind,var in enumerate(mod_var_list):
		i = var_list.index(var)
		axs[plot_ind].plot(impulse_res_CES[i][:t_max], marker='o', color='green', fillstyle='none')
		axs[plot_ind].plot(impulse_res_homog[i][:t_max], marker='o', color='darkorange', fillstyle='none')
		axs[plot_ind].plot(impulse_res[i][:t_max], marker='o', color='blue', fillstyle='none')
		axs[plot_ind].set_title(labels[i])
		axs[plot_ind].get_xaxis()._gridOnMajor = axs[plot_ind].get_xaxis()._major_tick_kw['gridOn']
		axs[plot_ind].get_xaxis()._gridOnMinor = axs[plot_ind].get_xaxis()._minor_tick_kw['gridOn']
		axs[plot_ind].get_yaxis()._gridOnMajor = axs[plot_ind].get_yaxis()._major_tick_kw['gridOn']
		axs[plot_ind].get_yaxis()._gridOnMinor = axs[plot_ind].get_yaxis()._minor_tick_kw['gridOn']
	plt.tight_layout(rect=[0, 0, 1, 0.95])
	if savefileroot:
	 	plt.suptitle(None)
	 	tikzplotlib.clean_figure()
	 	tikzplotlib.save(savefileroot + '_full.tex', extra_groupstyle_parameters=['horizontal sep=0.5in, vertical sep=0.7in'], extra_axis_parameters=['IRFStyle'])
	if mod_var_list2 is not None:
		fig,axs = plt.subplots(1,2)
		axs = axs.flatten()
		for plot_ind,var in enumerate(mod_var_list2):
			i = var_list.index(var)
			axs[plot_ind].plot(impulse_res_CES[i][:t_max], marker='o', color='green', fillstyle='none')
			axs[plot_ind].plot(impulse_res_homog[i][:t_max], marker='o', color='darkorange', fillstyle='none')
			axs[plot_ind].plot(impulse_res[i][:t_max], marker='o', color='blue', fillstyle='none')
			axs[plot_ind].set_title(labels[i])
			axs[plot_ind].get_xaxis()._gridOnMajor = axs[plot_ind].get_xaxis()._major_tick_kw['gridOn']
			axs[plot_ind].get_xaxis()._gridOnMinor = axs[plot_ind].get_xaxis()._minor_tick_kw['gridOn']
			axs[plot_ind].get_yaxis()._gridOnMajor = axs[plot_ind].get_yaxis()._major_tick_kw['gridOn']
			axs[plot_ind].get_yaxis()._gridOnMinor = axs[plot_ind].get_yaxis()._minor_tick_kw['gridOn']
		plt.tight_layout(rect=[0, 0, 1, 0.95])
		if savefileroot:
		 	plt.suptitle(None)
		 	tikzplotlib.clean_figure()
		 	tikzplotlib.save(savefileroot + '_extra.tex', extra_groupstyle_parameters=['horizontal sep=0.5in, vertical sep=0.7in'], extra_axis_parameters=['IRFStyle'])
	Y_index = var_list.index('Y')
	Pi_index = var_list.index('dpi')
	print(impulse_res[var_list.index('i_nom')][:20])
	fig,axs = plt.subplots(1,2)
	axs[0].plot(impulse_res[Y_index][:t_max], marker='o', fillstyle='none', color='blue', label='Heterogenous firms')
	axs[0].plot(impulse_res_homog[Y_index][:t_max], marker='o', fillstyle='none', color='darkorange', label='Homogenous firms')
	axs[0].plot(impulse_res_CES[Y_index][:t_max], marker='o', fillstyle='none', color='green', label='CES')
	#axs[0].set_ylabel('$d\\log Y_t$')
	axs[0].set_title('$d\\log Y_t$')
	axs[1].plot(impulse_res[Y_index][:t_max] / impulse_res_homog[Y_index][:t_max], marker='o', fillstyle='none', color='blue', label='Misallocation channel')
	axs[1].plot(impulse_res_homog[Y_index][:t_max] / impulse_res_CES[Y_index][:t_max], marker='o', fillstyle='none', color='darkorange', linestyle='--', label='Real rigidities')
	#axs[1].set_ylim((0,2))
	#axs[1].set_ylabel('Flattening from channel')
	axs[1].set_title('Flattening from channel')
	plt.suptitle(suptitle_base + ': $d\\log Y_t$')
	plt.tight_layout(rect=[0, 0, 1, 0.95])
	axs[0].ticklabel_format(useOffset=False, style='plain')
	axs[1].ticklabel_format(useOffset=False, style='plain')
	if savefileroot:
		plt.suptitle(None)
		tikzplotlib.clean_figure()
		axs[0].get_xaxis()._gridOnMajor = axs[0].get_xaxis()._major_tick_kw['gridOn']
		axs[1].get_xaxis()._gridOnMajor = axs[1].get_xaxis()._major_tick_kw['gridOn']
		axs[0].get_xaxis()._gridOnMinor = axs[0].get_xaxis()._minor_tick_kw['gridOn']
		axs[1].get_xaxis()._gridOnMinor = axs[1].get_xaxis()._minor_tick_kw['gridOn']
		axs[0].get_yaxis()._gridOnMajor = axs[0].get_yaxis()._major_tick_kw['gridOn']
		axs[1].get_yaxis()._gridOnMajor = axs[1].get_yaxis()._major_tick_kw['gridOn']
		axs[0].get_yaxis()._gridOnMinor = axs[0].get_yaxis()._minor_tick_kw['gridOn']
		axs[1].get_yaxis()._gridOnMinor = axs[1].get_yaxis()._minor_tick_kw['gridOn']
		tikzplotlib.save(savefileroot + '_dlogYdetail.tex', extra_axis_parameters=['PlotStyle'])
	### SIMPLE VERSION FOR VOXEU
	fig,axs = plt.subplots(1,2)
	axs[0].plot(impulse_res[Y_index][:t_max], marker='o', fillstyle='none', color='blue', label='Supply-side effect included')
	axs[0].plot(impulse_res_homog[Y_index][:t_max], marker='o', fillstyle='none', color='darkorange', label='Real rigidities included')
	axs[0].plot(impulse_res_CES[Y_index][:t_max], marker='o', fillstyle='none', color='green', label='Standard model')
	axs[0].legend()
	#axs[0].set_ylabel('$d\\log Y_t$')
	axs[0].set_title('Impact on output ($d\\log Y_t$)')
	axs[1].plot(impulse_res[Y_index][:t_max] / impulse_res_homog[Y_index][:t_max], marker='o', fillstyle='none', color='blue', label='Supply-side effect')
	axs[1].plot(impulse_res_homog[Y_index][:t_max] / impulse_res_CES[Y_index][:t_max], marker='o', fillstyle='none', color='darkorange', linestyle='--', label='Real rigidities')
	#axs[1].set_ylim((0,2))
	#axs[1].set_ylabel('Flattening from channel')
	axs[1].set_title('Flattening from channel')
	axs[1].legend()
	axs[0].ticklabel_format(useOffset=False, style='plain')
	axs[1].ticklabel_format(useOffset=False, style='plain')
	plt.tight_layout(rect=[0, 0, 1, 0.95])
	### VOXEU TFPR and TFP graphs
	fig,axs = plt.subplots(1,2)
	axs[0].plot(impulse_res_CES[var_list.index('A')][:t_max], marker='o', fillstyle='none', color='green', label='Standard model')
	axs[0].plot(impulse_res_homog[var_list.index('A')][:t_max], marker='o', fillstyle='none', color='darkorange', label='Real rigidities included')
	axs[0].plot(impulse_res[var_list.index('A')][:t_max], marker='o', fillstyle='none', color='blue', label='Supply-side effect included')
	axs[0].legend()
	axs[0].set_title('Aggregate TFP ($d\\log A_t$)')
	axs[1].plot(impulse_res_CES[var_list.index('std_mu')][:t_max], marker='o', fillstyle='none', color='green', label='Standard model')
	axs[1].plot(impulse_res_homog[var_list.index('std_mu')][:t_max], marker='o', fillstyle='none', color='darkorange', label='Real rigidities included')
	axs[1].plot(impulse_res[var_list.index('std_mu')][:t_max], marker='o', fillstyle='none', color='blue', label='Supply-side effect included')
	#axs[1].legend()
	axs[1].set_title('TFPR dispersion ($d\\log Std(TFPR)$)')
	axs[0].ticklabel_format(useOffset=False, style='plain')
	axs[1].ticklabel_format(useOffset=False, style='plain')
	plt.tight_layout(rect=[0, 0, 1, 0.95])
	if savefileroot:
		tikzplotlib.clean_figure()
		axs[0].get_xaxis()._gridOnMajor = axs[0].get_xaxis()._major_tick_kw['gridOn']
		axs[1].get_xaxis()._gridOnMajor = axs[1].get_xaxis()._major_tick_kw['gridOn']
		axs[0].get_xaxis()._gridOnMinor = axs[0].get_xaxis()._minor_tick_kw['gridOn']
		axs[1].get_xaxis()._gridOnMinor = axs[1].get_xaxis()._minor_tick_kw['gridOn']
		axs[0].get_yaxis()._gridOnMajor = axs[0].get_yaxis()._major_tick_kw['gridOn']
		axs[1].get_yaxis()._gridOnMajor = axs[1].get_yaxis()._major_tick_kw['gridOn']
		axs[0].get_yaxis()._gridOnMinor = axs[0].get_yaxis()._minor_tick_kw['gridOn']
		axs[1].get_yaxis()._gridOnMinor = axs[1].get_yaxis()._minor_tick_kw['gridOn']
		tikzplotlib.save(savefileroot + '_TFP_TFPR.tex', extra_axis_parameters=['PlotStyle'])
	# Plot persistence graphs
	fig,axs = plt.subplots(1,2)
	axs[0].plot(np.log(0.5) / np.log(np.roll(impulse_res_CES[Y_index],-1)/impulse_res_CES[Y_index])[:2*t_max], color='green', label='CES')
	axs[0].plot(np.log(0.5) / np.log(np.roll(impulse_res_homog[Y_index],-1)/impulse_res_homog[Y_index])[:2*t_max], color='darkorange', label='Homogenous firms')
	axs[0].plot(np.log(0.5) / np.log(np.roll(impulse_res[Y_index],-1)/impulse_res[Y_index])[:2*t_max], color='blue', label='Heterogenous firms')
	print(np.log(np.roll(impulse_res_CES[Y_index],-1)/impulse_res_CES[Y_index])[:2*t_max] / np.log(np.roll(impulse_res[Y_index],-1)/impulse_res[Y_index])[:2*t_max])
	axs[0].set_title('$d\\log Y_t$')
	axs[0].set_ylabel('Instantaneous half-life')
	axs[0].legend()
	axs[1].plot(np.log(0.5) / np.log(np.roll(impulse_res_CES[Pi_index],-1)/impulse_res_CES[Pi_index])[:2*t_max], color='green', label='CES')
	axs[1].plot(np.log(0.5) / np.log(np.roll(impulse_res_homog[Pi_index],-1)/impulse_res_homog[Pi_index])[:2*t_max], color='darkorange', label='Homogenous firms')
	axs[1].plot(np.log(0.5) / np.log(np.roll(impulse_res[Pi_index],-1)/impulse_res[Pi_index])[:2*t_max], color='blue', label='Heterogenous firms')
	print(np.log(np.roll(impulse_res_CES[Pi_index],-1)/impulse_res_CES[Pi_index])[:2*t_max] / np.log(np.roll(impulse_res[Pi_index],-1)/impulse_res[Pi_index])[:2*t_max])
	axs[1].set_title('$d\\log \\pi_t$')
	if savefileroot:
		plt.suptitle(None)
		tikzplotlib.clean_figure()
		axs[0].get_xaxis()._gridOnMajor = axs[0].get_xaxis()._major_tick_kw['gridOn']
		axs[1].get_xaxis()._gridOnMajor = axs[1].get_xaxis()._major_tick_kw['gridOn']
		axs[0].get_xaxis()._gridOnMinor = axs[0].get_xaxis()._minor_tick_kw['gridOn']
		axs[1].get_xaxis()._gridOnMinor = axs[1].get_xaxis()._minor_tick_kw['gridOn']
		axs[0].get_yaxis()._gridOnMajor = axs[0].get_yaxis()._major_tick_kw['gridOn']
		axs[1].get_yaxis()._gridOnMajor = axs[1].get_yaxis()._major_tick_kw['gridOn']
		axs[0].get_yaxis()._gridOnMinor = axs[0].get_yaxis()._minor_tick_kw['gridOn']
		axs[1].get_yaxis()._gridOnMinor = axs[1].get_yaxis()._minor_tick_kw['gridOn']
		tikzplotlib.save(savefileroot + '_dlogYhalflife.tex', extra_axis_parameters=['PlotStyle'])
	calculate_ir_stats(impulse_res, impulse_res_homog, impulse_res_CES, Y_index)

def driver():
	"""Run the two baseline impulse-response experiments and display the figures.

	Loads the firm distribution with ``sp.get_key_darwinian_data()`` and rebuilds
	the markups with ``sp.generate_markups(passthroughs, shares, 1.15)``. It then
	calls ``plot_all_impulse_responses`` twice, with 1000 periods and no file
	output (``savefileroot=None``):

	- first for ``money_supply_impulse_responses_cia`` with ``shock_type='money'``;
	- then for ``monetary_policy_impulse_responses`` with
	  ``shock_type='interest_rate'``.

	Finally it opens every figure with ``plt.show()``.
	"""
	# Only the shares and pass-throughs from the data are used. The markups are
	# regenerated from them just below, and the last two returned values are
	# not needed here.
	shares, _data_markups, passthroughs, _prod, _output = sp.get_key_darwinian_data()
	markups = sp.generate_markups(passthroughs, shares, 1.15)

	# To run Klenow-Willis instead
	# shares,markups,passthroughs = kw.generate_kw_distributions(superelast=1.6)
	# plot_all_impulse_responses(monetary_policy_impulse_responses, ['v','delta','mu','A','Y','Lambda','L','w','i_nom','r','p', 'dpi', 'M'], 'Response to exogenous monetary policy shock', shares, markups, passthroughs, periods=1000, savefileroot='../Draft/figures/klenow_willis/belgian_dynamics_kw')

	# Arguments shared by both experiments.
	title = 'Response to exogenous monetary policy shock'
	firm_data = (shares, markups, passthroughs)
	shared_kwargs = {'periods': 1000, 'savefileroot': None}

	# Money-supply shock: the money stock 'M' is the first plotted variable.
	money_vars = ['M', 'delta', 'mu', 'A', 'Y', 'Lambda', 'L', 'w', 'i_nom', 'r', 'p', 'dpi']
	plot_all_impulse_responses(money_supply_impulse_responses_cia, money_vars, title, *firm_data, shock_type='money', **shared_kwargs)

	# Interest-rate shock: 'v' leads the list and 'M' is plotted last.
	rate_vars = ['v', 'delta', 'mu', 'A', 'Y', 'Lambda', 'L', 'w', 'i_nom', 'r', 'p', 'dpi', 'M']
	plot_all_impulse_responses(monetary_policy_impulse_responses, rate_vars, title, *firm_data, shock_type='interest_rate', **shared_kwargs)

	plt.show()

# Run the experiments only when this file is executed as a script. Importing the
# module (e.g. to reuse compute_other_impulse_responses) then no longer triggers
# both simulations and the blocking plt.show() call inside driver().
if __name__ == "__main__":
	driver()
